# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference decorators tests."""

import inspect

import numpy as np
import pytest

from pytriton.decorators import TritonContext, batch
from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig
from pytriton.models.model import _inject_triton_context
from pytriton.proxy.types import Request

# Shared fixture: three requests with different batch sizes (2, 3 and 1).
# Each Request carries two named input tensors ("a", "b") and empty parameters.
_sample_requests = [
    Request({"b": np.array([[7, 5], [8, 6]]), "a": np.array([[1], [1]])}, {}),
    Request({"b": np.array([[1, 2], [1, 2], [11, 12]]), "a": np.array([[1], [1], [1]])}, {}),
    Request({"b": np.array([[1, 2]]), "a": np.array([[1]])}, {}),
]

# Expected per-request outputs for an infer function that doubles every
# input tensor — mirrors `_sample_requests` element-wise times two.
_sample_requests_multiplied_by_2 = [
    {"b": np.array([[14, 10], [16, 12]]), "a": np.array([[2], [2]])},
    {"b": np.array([[2, 4], [2, 4], [22, 24]]), "a": np.array([[2], [2], [2]])},
    {"b": np.array([[2, 4]]), "a": np.array([[2]])},
]


@batch
def batched_multiply_2(**_inputs):
    """Batched infer function: return a dict with every input tensor doubled."""
    assert isinstance(_inputs, dict)
    outputs = {}
    for name, tensor in _inputs.items():
        outputs[name] = tensor * 2
    return outputs


@batch
def batched_multiply_2_gen(**_inputs):
    """Generator variant of ``batched_multiply_2``: yields one doubled dict."""
    assert isinstance(_inputs, dict)
    outputs = {}
    for name, tensor in _inputs.items():
        outputs[name] = tensor * 2
    yield outputs


@pytest.mark.parametrize(
    "inputs, infer_fn, expected",
    (
        (
            _sample_requests,
            batched_multiply_2,
            _sample_requests_multiplied_by_2,
        ),
        (
            _sample_requests,
            batched_multiply_2_gen,
            _sample_requests_multiplied_by_2,
        ),
    ),
)
def test_batch(inputs, infer_fn, expected):
    """Run a @batch-decorated infer function and compare per-request outputs."""
    outputs = infer_fn(inputs)

    if inspect.isgenerator(outputs):
        # A generator infer function yields, on each yield, a list holding one
        # result per request — concatenate all partial lists into a flat list.
        flattened = []
        for partial_results in outputs:
            flattened.extend(partial_results)
        outputs = flattened

    assert len(expected) == len(outputs)  # one result per request
    for expected_result, result in zip(expected, outputs):
        assert list(expected_result) == list(result)  # same output names (and order)
        for name in expected_result:
            assert np.equal(expected_result[name], result[name]).all()


@batch
def batched_multiply_2_returning_list(**_inputs):
    """Batched infer function returning doubled outputs as a plain list.

    List outputs rely on the model config (via the Triton context) to map
    positions back to output names.
    """
    assert isinstance(_inputs, dict)
    doubled = []
    for tensor in _inputs.values():
        doubled.append(tensor * 2)
    return doubled


@batch
def batched_multiply_2_returning_list_gen(**_inputs):
    """Generator variant of ``batched_multiply_2_returning_list``."""
    assert isinstance(_inputs, dict)
    doubled = []
    for tensor in _inputs.values():
        doubled.append(tensor * 2)
    yield doubled


def _prepare_and_inject_context_with_config(config, fn):
    """Build a TritonContext holding *config* for *fn*, inject it into *fn*.

    Returns the created context so callers can inspect or reuse it.
    """
    triton_context = TritonContext()
    triton_context.model_configs[fn] = config
    _inject_triton_context(triton_context, fn)
    return triton_context


@pytest.mark.parametrize(
    "inputs, infer_fn, expected",
    (
        (
            _sample_requests,
            batched_multiply_2_returning_list,
            _sample_requests_multiplied_by_2,
        ),
        (
            _sample_requests,
            batched_multiply_2_returning_list_gen,
            _sample_requests_multiplied_by_2,
        ),
    ),
)
def test_batch_with_context(inputs, infer_fn, expected):
    """Infer functions returning lists need output names from the model config."""
    # Derive tensor specs from the first request/expected result; the injected
    # context lets @batch map positional list outputs back to output names.
    first_request = inputs[0]
    first_expected = expected[0]
    model_config = TritonModelConfig(
        "my_model",
        inputs=[TensorSpec(name, tensor.shape, tensor.dtype) for name, tensor in first_request.items()],
        outputs=[TensorSpec(name, tensor.shape, tensor.dtype) for name, tensor in first_expected.items()],
    )
    _prepare_and_inject_context_with_config(config=model_config, fn=infer_fn)

    # With the context in place, the generic batch test applies unchanged.
    test_batch(inputs, infer_fn, expected)


def test_batch_raises_on_incorrect_batch_size_of_outputs():
    """@batch must reject outputs whose batch size differs from the inputs'."""

    @batch
    def _infer_fn(**_inputs):
        # Truncate every output to batch size 1, which mismatches the
        # (larger) input batch sizes and must trigger a ValueError.
        truncated = {}
        for name, tensor in _inputs.items():
            truncated[name] = tensor[:1] * 2
        return truncated

    with pytest.raises(ValueError, match=r"Received output tensors with different batch sizes"):
        _infer_fn(_sample_requests)
